import torch
import torch.nn as nn
from .base_embedding import BaseEmbedding

class RandomImageEmbedding(BaseEmbedding):
    """Image embedding that ignores token values and emits noise (or zeros)
    plus optional learned row/column position embeddings.

    Unlike a lookup-table embedding, ``forward`` does not read the values in
    ``index``; it always returns a ``B x (H*W) x embed_dim`` float32 tensor
    where ``(H, W) = spatial_size``.

    Args:
        num_embed: Stored as ``self.num_embed``; not used by this class
            (presumably the codebook size, kept for interface compatibility
            with sibling embeddings — confirm).
        spatial_size: ``(height, width)`` of the token grid, or a single int
            for a square grid.
        embed_dim: Size of each output embedding vector.
        trainable: Stored as ``self.trainable``; presumably consumed by
            ``self._set_trainable()`` from ``BaseEmbedding`` — confirm.
        pos_emb_type: ``'embedding'`` (two ``nn.Embedding`` tables),
            ``'parameter'`` (two zero-initialised ``nn.Parameter`` tensors) or
            ``'None'`` (no positional term).
        random_noise: If truthy, the base tensor is freshly sampled from
            N(0, 1) on every call; otherwise it is all zeros.
    """

    def __init__(self,
                 num_embed=8192,
                 spatial_size=(32, 32),  # (height, width)
                 embed_dim=3968,
                 trainable=True,
                 pos_emb_type='embedding',
                 random_noise=True
        ):
        super().__init__()

        if isinstance(spatial_size, int):
            spatial_size = [spatial_size, spatial_size]

        # Copy into a fresh list so instances never share the caller's (or
        # the default's) sequence object.
        self.spatial_size = list(spatial_size)
        self.num_embed = num_embed
        self.embed_dim = embed_dim
        self.trainable = trainable
        self.pos_emb_type = pos_emb_type
        self.random_noise = random_noise

        assert self.pos_emb_type in ['embedding', 'parameter', 'None'], (
            f"unsupported pos_emb_type: {self.pos_emb_type!r}")

        height, width = self.spatial_size[0], self.spatial_size[1]
        if self.pos_emb_type == 'embedding':
            self.height_emb = nn.Embedding(height, embed_dim)  # one row per grid row
            self.width_emb = nn.Embedding(width, embed_dim)  # one row per grid column
        elif self.pos_emb_type == 'parameter':
            self.height_emb = nn.Parameter(torch.zeros(1, height, embed_dim))
            self.width_emb = nn.Parameter(torch.zeros(1, width, embed_dim))
        self._set_trainable()

    def _output_device(self, index):
        """Return the device to build the output on.

        Uses the positional parameters' device when they exist (so the sum
        with the positional term never mixes devices), else ``index.device``.
        """
        if self.pos_emb_type == 'embedding':
            return self.height_emb.weight.device
        if self.pos_emb_type == 'parameter':
            return self.height_emb.device
        return index.device

    def _position_embedding(self, height, width, device):
        """Return the ``1 x (H*W) x D`` sum of row and column embeddings."""
        if self.pos_emb_type == 'embedding':
            rows = torch.arange(height, device=device).view(1, height)
            cols = torch.arange(width, device=device).view(1, width)
            height_emb = self.height_emb(rows).unsqueeze(2)  # 1 x H x D -> 1 x H x 1 x D
            width_emb = self.width_emb(cols).unsqueeze(1)  # 1 x W x D -> 1 x 1 x W x D
        else:
            height_emb = self.height_emb.unsqueeze(2)  # 1 x H x D -> 1 x H x 1 x D
            width_emb = self.width_emb.unsqueeze(1)  # 1 x W x D -> 1 x 1 x W x D
        # Broadcasting gives 1 x H x W x D; flatten row-major to 1 x (H*W) x D.
        return (height_emb + width_emb).view(1, height * width, -1)

    def forward(self, index, **kwargs):
        """Build the image embedding for a batch.

        Args:
            index: 2-D tensor ``B x L``. Only its batch size (and, when there
                is no positional term, its device) is used; token values and
                ``L`` are ignored.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            float32 tensor of shape ``B x (H*W) x embed_dim``.
        """
        assert index.dim() == 2, (
            f"expected index of shape B x L, got {tuple(index.shape)}")
        batch_size = index.size(0)
        height, width = self.spatial_size[0], self.spatial_size[1]
        num_positions = height * width

        # Allocate directly on the target device: building on the host and
        # copying would move B*H*W*D floats across the bus on every call.
        device = self._output_device(index)
        shape = (batch_size, num_positions, self.embed_dim)
        if self.random_noise:
            emb = torch.randn(shape, device=device, dtype=torch.float32)
        else:
            emb = torch.zeros(shape, device=device, dtype=torch.float32)

        if num_positions > 0 and self.pos_emb_type != 'None':
            emb = emb + self._position_embedding(height, width, device)

        return emb
